#include "mpi.h"
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define MASTER 0

/*
 * Computes the ranks of the four mesh neighbors of `rank` in an N x N grid
 * laid out in row-major order (row = rank / N, column = rank % N).
 *
 * rank  - rank whose neighbors are requested
 * N     - side length of the grid (must be non-zero)
 * up, down, left, right - outputs; each receives the neighbor's rank, or -1
 *                         when the rank sits on that edge of the grid
 */
void get_neighbors(int rank, int N, int *up, int *down, int *left, int *right) {
    const int grid_row = rank / N;
    const int grid_col = rank % N;
    const int last_index = N - 1;

    // Vertical neighbors are one full row (N ranks) away.
    if (grid_row > 0) {
        *up = rank - N;
    } else {
        *up = -1;
    }

    if (grid_row < last_index) {
        *down = rank + N;
    } else {
        *down = -1;
    }

    // Horizontal neighbors are adjacent ranks within the same row.
    if (grid_col > 0) {
        *left = rank - 1;
    } else {
        *left = -1;
    }

    if (grid_col < last_index) {
        *right = rank + 1;
    } else {
        *right = -1;
    }
}

/*
 * Propagates the two-element parameter array from MASTER to every rank of
 * the N x N mesh using only nearest-neighbor messages (tag 2).
 *
 * Routing: ranks in column 0 receive from the rank above them, all other
 * ranks receive from their left neighbor. Every rank then forwards to its
 * right neighbor, and column-0 ranks additionally forward downward.
 *
 * rank   - calling process's rank in MPI_COMM_WORLD
 * N      - side length of the mesh
 * params - buffer of 2 ints; input on MASTER, overwritten on every other rank
 *
 * Every rank of the mesh must call this function, otherwise its neighbors
 * block waiting for it.
 */
void broadcast_params(int rank, int N, int *params) {
    const int tag = 2;
    int col = rank % N;
    int up, down, left, right;
    get_neighbors(rank, N, &up, &down, &left, &right);

    // Receive first, then forward. The blocking receive already orders the
    // wavefront, so no barriers are needed, and every send is matched by a
    // receive that is posted (or about to be posted) by the destination.
    // This stays correct even if MPI_Send does not buffer the message.
    if (rank != MASTER) {
        int parent = (col > 0) ? left : up;
        if (parent != -1) {
            MPI_Recv(params, 2, MPI_INT, parent, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        }
    }

    if (right != -1) {
        MPI_Send(params, 2, MPI_INT, right, tag, MPI_COMM_WORLD);
    }
    // Only column 0 feeds the next row; other columns are fed from the left.
    if (down != -1 && col == 0) {
        MPI_Send(params, 2, MPI_INT, down, tag, MPI_COMM_WORLD);
    }
}
/*
 * Propagates `size` ints from MASTER's `vector` to every rank of the N x N
 * mesh using only nearest-neighbor messages (tag 0).
 *
 * Routing: ranks in column 0 receive from the rank above them, all other
 * ranks receive from their left neighbor. Every rank then forwards to its
 * right neighbor, and column-0 ranks additionally forward downward.
 *
 * rank   - calling process's rank in MPI_COMM_WORLD
 * N      - side length of the mesh
 * vector - buffer of at least `size` ints; input on MASTER, overwritten on
 *          every other rank
 * size   - number of ints to transfer
 *
 * Every rank of the mesh must call this function, otherwise its neighbors
 * block waiting for it.
 */
void broadcast_vector(int rank, int N, int *vector, int size) {
    const int tag = 0;
    int col = rank % N;
    int up, down, left, right;
    get_neighbors(rank, N, &up, &down, &left, &right);

    // Receive first, then forward. Standard-mode MPI_Send may block until the
    // matching receive is posted (typical for large messages), so a sender
    // must never have to pass a barrier before its receiver starts receiving.
    // Here the destination is always already waiting in MPI_Recv or will
    // reach it without depending on this rank, so the pattern cannot deadlock.
    if (rank != MASTER) {
        int parent = (col > 0) ? left : up;
        if (parent != -1) {
            MPI_Recv(vector, size, MPI_INT, parent, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        }
    }

    if (right != -1) {
        MPI_Send(vector, size, MPI_INT, right, tag, MPI_COMM_WORLD);
    }
    // Only column 0 feeds the next row; other columns are fed from the left.
    if (down != -1 && col == 0) {
        MPI_Send(vector, size, MPI_INT, down, tag, MPI_COMM_WORLD);
    }
}


/*
 * Sums the per-rank counts toward MASTER over the N x N mesh using only
 * nearest-neighbor messages (tag 1). This is the reverse of the broadcast
 * routing: each rank adds what it receives from its right neighbor (and,
 * in column 0, from the rank below) and passes the partial sum to its left
 * neighbor, or upward when it is in column 0.
 *
 * rank  - calling process's rank in MPI_COMM_WORLD
 * N     - side length of the mesh
 * count - in: this rank's local count; out: local count plus everything
 *         received from its neighbors (the grand total on MASTER)
 *
 * Every rank of the mesh must call this function, otherwise its neighbors
 * block waiting for it.
 */
void reduce_counts(int rank, int N, int *count) {
    const int tag = 1;
    int col = rank % N;
    int up, down, left, right;
    get_neighbors(rank, N, &up, &down, &left, &right);

    // Collect from children before reporting to the parent. The blocking
    // receives order the reduction by themselves, so no barriers are needed
    // and correctness does not depend on MPI_Send buffering the message.
    if (right != -1) {
        int recv_count;
        MPI_Recv(&recv_count, 1, MPI_INT, right, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        *count += recv_count;
    }
    // Only column 0 gathers the row totals coming up from the row below.
    if (down != -1 && col == 0) {
        int recv_count;
        MPI_Recv(&recv_count, 1, MPI_INT, down, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        *count += recv_count;
    }

    if (rank != MASTER) {
        int parent = (col > 0) ? left : up;
        if (parent != -1) {
            MPI_Send(count, 1, MPI_INT, parent, tag, MPI_COMM_WORLD);
        }
    }
}

/*
 * Parses `text` as a base-10 integer.
 * Returns 1 and stores the value in *out on success. Returns 0, leaving
 * *out untouched, when the text holds no digits, has trailing characters,
 * or does not fit in an int.
 */
static int parse_int_arg(const char *text, int *out) {
    char *end = NULL;
    errno = 0;
    long value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE ||
        value < INT_MIN || value > INT_MAX) {
        return 0;
    }
    *out = (int)value;
    return 1;
}

/*
 * Counts how many elements of a generated vector equal E, on a square mesh
 * of MPI processes.
 *
 * Command line (read on MASTER only): E k
 *   E - value to search for
 *   k - number of vector elements examined by each process (k >= 0)
 *
 * MASTER builds a vector of k * P elements, where element i is (i + 3) % 7,
 * and distributes the parameters and the vector over the mesh. Each rank
 * counts matches in its own k-element slice, the counts are summed back to
 * MASTER, and MASTER prints the total.
 */
int main(int argc, char *argv[]) {
    int procs, rank;
    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &procs);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    if (rank == MASTER) {
        if (argc != 3) {
            fprintf(stderr, "Usage: mpirun -np P %s E k\n", argv[0]);
            MPI_Abort(MPI_COMM_WORLD, 1);
            return 1;
        }
    }

    int P = procs;
    int N = (int)sqrt((double)P);

    if (rank == MASTER && N * N != P) {
        fprintf(stderr, "Number of processes must be a perfect square\n");
        MPI_Abort(MPI_COMM_WORLD, 1);
        return 1;
    }

    int params[2] = {0, 0};
    if (rank == MASTER) {
        // Validate on MASTER before anything is sent, so no rank ever works
        // with a garbage E or k.
        if (!parse_int_arg(argv[1], &params[0]) ||   // E
            !parse_int_arg(argv[2], &params[1])) {   // k
            fprintf(stderr, "Error: E and k must be integers\n");
            MPI_Abort(MPI_COMM_WORLD, 1);
            return 1;
        }
        // k * P is used as an int element count for the allocation and the
        // messages, so it must be non-negative and must not overflow.
        if (params[1] < 0 || params[1] > INT_MAX / P) {
            fprintf(stderr, "Error: k must be between 0 and %d\n", INT_MAX / P);
            MPI_Abort(MPI_COMM_WORLD, 1);
            return 1;
        }
    }

    broadcast_params(rank, N, params);
    int E = params[0];
    int k = params[1];

    int total_size = k * P;
    // malloc(0) is allowed to return NULL, which would be misreported as an
    // allocation failure when k == 0; always request at least one element.
    size_t alloc_count = (total_size > 0) ? (size_t)total_size : 1;
    int *full_vector = malloc(alloc_count * sizeof *full_vector);
    if (full_vector == NULL) {
        fprintf(stderr, "Error: Failed to allocate memory for vector\n");
        MPI_Abort(MPI_COMM_WORLD, 1);
        return 1;
    }

    if (rank == MASTER) {
        for (int i = 0; i < total_size; i++) {
            full_vector[i] = (i + 3) % 7;
        }
    }

    broadcast_vector(rank, N, full_vector, total_size);

    // Each rank scans only its own contiguous slice of k elements.
    int chunk_size = k;
    int start_idx = rank * chunk_size;
    int local_count = 0;

    for (int i = 0; i < chunk_size; i++) {
        if (full_vector[start_idx + i] == E) {
            local_count++;
        }
    }

    free(full_vector);
    reduce_counts(rank, N, &local_count);

    // After the reduction MASTER's count is the total over all ranks.
    if (rank == MASTER) {
        printf("%d\n", local_count);
    }

    MPI_Finalize();
    return 0;
}
